{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp callback.mixup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.basics import *\n",
    "from fastai.callback.progress import *\n",
    "from fastai.vision.core import *\n",
    "from fastai.vision.models.xresnet import *\n",
    "\n",
    "from torch.distributions.beta import Beta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *\n",
    "from fastai.test_utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mixup callback\n",
    "\n",
    "> Callback to apply MixUp data augmentation to your training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MixupCallback -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def reduce_loss(loss, reduction='mean'):\n",
    "    \"Apply `reduction` to `loss`: 'mean' averages, 'sum' adds up, anything else returns `loss` untouched\"\n",
    "    if reduction == 'mean': return loss.mean()\n",
    "    if reduction == 'sum': return loss.sum()\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class MixUp(Callback):\n",
    "    \"Blend every training batch with a shuffled copy of itself, using per-item weights drawn from Beta(`alpha`,`alpha`)\"\n",
    "    run_after,run_valid = [Normalize],False\n",
    "    def __init__(self, alpha=0.4): self.distrib = Beta(tensor(alpha), tensor(alpha))\n",
    "    def before_fit(self):\n",
    "        \"Route the loss through `self.lf` when the learner's loss function has a truthy `y_int` attribute\"\n",
    "        # In that case `before_batch` leaves the targets alone and `lf` mixes the two losses instead\n",
    "        self.stack_y = getattr(self.learn.loss_func, 'y_int', False)\n",
    "        if self.stack_y: self.old_lf,self.learn.loss_func = self.learn.loss_func,self.lf\n",
    "\n",
    "    def after_fit(self):\n",
    "        \"Restore the loss function that `before_fit` replaced\"\n",
    "        if self.stack_y: self.learn.loss_func = self.old_lf\n",
    "\n",
    "    def before_batch(self):\n",
    "        \"Interpolate `xb` (and `yb` unless `stack_y`) with a random permutation of the same batch\"\n",
    "        bs = self.y.size(0)\n",
    "        lam = self.distrib.sample((bs,)).squeeze().to(self.x.device)\n",
    "        # With a single-item batch `squeeze` returns a 0-d tensor; restore the batch axis so there is\n",
    "        # still one weight per item for the per-item operations below\n",
    "        if lam.dim() == 0: lam = lam[None]\n",
    "        # Elementwise max of `lam` and `1-lam`, so every weight is >= 0.5\n",
    "        self.lam = torch.max(lam, 1-lam)\n",
    "        shuffle = torch.randperm(bs).to(self.x.device)\n",
    "        xb1,self.yb1 = tuple(L(self.xb).itemgot(shuffle)),tuple(L(self.yb).itemgot(shuffle))\n",
    "        nx_dims = len(self.x.size())\n",
    "        # `torch.lerp(a, b, w)` is `a + w*(b-a)`, so `w` weights the unshuffled batch.\n",
    "        # NOTE(review): `unsqueeze(..., n=k)` presumably adds `k` unit dims so the weights broadcast -- confirm\n",
    "        self.learn.xb = tuple(L(xb1,self.xb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=nx_dims-1)))\n",
    "        if not self.stack_y:\n",
    "            ny_dims = len(self.y.size())\n",
    "            self.learn.yb = tuple(L(self.yb1,self.yb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=ny_dims-1)))\n",
    "\n",
    "    def lf(self, pred, *yb):\n",
    "        \"Loss installed by `before_fit`: mix the losses for the shuffled and the original targets with `self.lam`\"\n",
    "        # Outside training no mixing happened, so defer to the original loss function\n",
    "        if not self.training: return self.old_lf(pred, *yb)\n",
    "        # NOTE(review): `NoneReduce` is expected to give the loss without its reduction, so the mix is done\n",
    "        # per item before `reduce_loss` applies the original reduction -- confirm\n",
    "        with NoneReduce(self.old_lf) as lf:\n",
    "            loss = torch.lerp(lf(pred,*self.yb1), lf(pred,*yb), self.lam)\n",
    "        return reduce_loss(loss, getattr(self.old_lf, 'reduction', 'mean'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.vision.core import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small image dataset used to try the callback out\n",
    "path = untar_data(URLs.MNIST_TINY)\n",
    "items = get_image_files(path)\n",
    "# One pipeline for the inputs (black-and-white images) and one for the targets (labels from the parent folder)\n",
    "x_tfms,y_tfms = PILImageBW.create,[parent_label, Categorize()]\n",
    "splits = GrandparentSplitter()(items)\n",
    "tds = Datasets(items, [x_tfms, y_tfms], splits=splits)\n",
    "dls = tds.dataloaders(after_item=[ToTensor(), IntToFloatTensor()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgAAAAIHCAYAAADpfeRCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXCV1f3H8fMAgUASCIGwhLDINiBUFhWsIoJIq4i2gKLiVim4LxUrtopjLdoZRZ1a1P5EreLS2gGxdcRWkCpKUbQoO4qyKPsmEgLZSO7vD2qn5fu9cnKfJPc+9/t+zTjTfniWk3Dy5Ovj954TxGIxBwAAbKmX7AEAAIC6RwEAAIBBFAAAABhEAQAAgEEUAAAAGEQBAACAQRQAAAAYRAFQTUEQFB/1T2UQBNOTPS6gOpjHiDrmcHgNkj2AqInFYtnf/u8gCLKcczudc7OSNyKg+pjHiDrmcHi8AQjnAufcLufce8keCBAC8xhRxxxOAAVAOFc6556PsZ4yoo15jKhjDicg4PuVmCAIOjjnNjrnusZisY3JHg+QCOYxoo45nDjeACTuCufcIiYcIo55jKhjDieIAiBxVzjnZiZ7EEBIzGNEHXM4QfwngAQEQXCqc26+c65NLBY7kOzxAIlgHiPqmMPh8AYgMVc65+Yw4RBxzGNEHXM4BN4AAABgEG8AAAAwiAIAAACDKAAAADCIAgAAAIMoAAAAMOhYuwHyEQGEESR7AI45jHBSYQ47xzxGOOo85g0AAAAGUQAAAGAQBQAAAAZRAAAAYBAFAAAABlEAAABgEAUAAAAGUQAAAGAQBQAAAAZRAAAAYBAFAAAABlEAAABgEAUAAAAGUQAAAGAQBQAAAAZRAAAAYBAFAAAABlEAAABgEAUAAAAGUQAAAGAQBQAAAAZRAAAAYBAFAAAABlEAAABgEAUAAAAGUQAAAGAQBQAAAAZRAAAAYBAFAAAABjVI9gDSTSwWE9nhw4dFVlFRoZ4fLz9ao0aNvLIgCLyuBwCa8vJykVVVVYmsfv36Xtdr0KDmf+1ozzntWczz8H/xBgAAAIMoAAAAMIgCAAAAgygAAAAwiCZATwcPHhTZ+vXrRTZnzhyRTZ06VWRaE41zepPKZZddJrJf//rXIuvQoYPX9YDaVFxc7HXcP/7xD5F99NFHInvxxRdFNm7cOJHdf//9XvdFfFrjnPas8m121vgeVx1aA2JlZaXI6tWT/85bnWdkw4YNqzewFMcbAAAADKIAAADAIAoAAAAMogAAAMCgtGwC1BpUDh06pB6rNffNmjVLZPfdd5/Idu7cKbKMjAyR5ebmiqx9+/bqeL755huRaU1Qt912m8ho+EsebbW03bt3e52bn58vsmQ2G2lfi5Zt375dPV9rhNW89tprIlu8eLHXuVqDF8LTniF/+ctfRKY9kyZPniyydu3aiaygoCDB0cXnOx+2bNnilTnn3IIFC0R2xx13iCwrK8vr3qmINwAAABhEAQAAgEEUAAAAGEQBAACAQZFqAiwtLRXZhg0bRPbwww+LbNGiReo1P//8c5FpjTDaVrs9evQQ2cknnyyyW2+9VWTNmjVTxzN8+HA1P5q2pSZNgMmjNcn58m0WjDdntKa7PXv2iExryGrTpo3IunbtKrIlS5aIbOnSpep4fvOb36h5orTV2yZOnFij90B82ip7y5YtE9m5554rMm0VwdNPP9373tr26L6r/mkWLlwosnirsmqGDh0qsiFDhogsKs9i3gAAAGAQBQAAAAZRAAAAYBAFAAAABqVEE2BZWZnIHnnkEZHNnTtXZL4rh2VmZqp5586dRXbOOeeI7PLLLxeZ1iyVl5fnNZ4DBw6ouda8pYny6lOWbd26VWTaFriPPfaYyPbt26dec+/evSLTtlzVVqnUVhzMzs72uneYxsd4tDHOnj1bZF26dKnxe0N34YUXimzNmjUie/3110W2adMm
kX3yySfqfbRmPK0B1Jc2P6vTnFdYWCgybd5p49YaFVMRbwAAADCIAgAAAIMoAAAAMIgCAAAAgygAAAAwKCU+BfD73/9eZFoXtLb/uNYNry2ZOmzYMPXekyZNElmfPn3UY2vSrFmz1LyoqEhk2jLEydwvHtLq1atFNm7cOJFp3fTxuvtrmrasqpYdPHjQ63ra0sLO6XM43qdejqZ9CmDQoEFe56Lu3HPPPSKbPHmyyLZs2eJ9zc8++0xk2pLny5cvF5n2qZexY8eK7KuvvhJZvOXXb7rpJpG1atVKZFHp+NfwBgAAAIMoAAAAMIgCAAAAgygAAAAwKCWaADt06CCyxo0bi6x9+/YiO+2000Q2ePBgkY0fP169d10002l7Va9du9b7/DFjxoisdevWocaExG3YsEFk559/vsh27drldb2cnByRNW3aVGQjRoxQz9eaVrW92Zs3b+41Hm2Z1+OOO05kWiOYc87NmDHD6z79+vUT2XXXXScybWlipB7tmd2tWzfv832PPfvss72vebTqPO8vueQSkVVnKeEo4A0AAAAGUQAAAGAQBQAAAAZRAAAAYFBKNAGOGjVKZOedd57Xudp+0am2MlNZWZnI3n//ffVYreFpypQpIku1r9ESbd/zPXv2JHy9vn37imzIkCEi+8UvfqGe36RJk4Tvrenfv7/IHnzwQZHNnTvX+5qZmZki05oITzrpJJGx6iUSoTVfv/DCCyI744wz1PNzc3NFpq1UGWW8AQAAwCAKAAAADKIAAADAIAoAAAAMSokmQG11pag2W2iNJ//85z9FtmTJEvV8bRXDeNuuIjm0rX+rqqpEpq1cecIJJ4jsiSeeEJn2d14bjZ9ag+qHH34osvnz53tfU1sR7pFHHhGZ1vCnnQskQnsWL1u2TGTxnq/az5vWdB5l6fXVAAAALxQAAAAYRAEAAIBBFAAAABiUEk2AUaU1fpWUlIjs7rvvFlm87Xzbtm0rMm27WNS+4uJiNb/xxhtFVlhYKDJthb5rrrlGZMnc7lbb+veuu+7yOreoqEjNtS26fVf2ZOtf1BStCXDBggUimzBhgnp+ujX8adL/KwQAAAIFAAAABlEAAABgEAUAAAAG0QQYQkVFhci0xhNt5bj8/Hz1mtrWyEgteXl5Ihs/frzXuclscps5c6bIXn31Va9z9+3bJ7KmTZuqx/7ud7/zuqa26h9b/yIR2nN3165dItNWvhw9erR6TQtzkTcAAAAYRAEAAIBBFAAAABhEAQAAgEE0AXo6fPiwyEpLS0Wmbf2rNY116tRJvc+IESOqPzjUivLy8mQPIWGfffaZyLSGv6+//trregcPHhTZz3/+c/XYeM2BR7PQZIW64bvqX4MG8leetnqrFbwBAADAIAoAAAAMogAAAMAgCgAAAAyiAAAAwCA+BaDQOkq1jv8gCET229/+VmSNGjUS2fXXX6/eW1seFckRb9lerWtY+3urq2V/tU8rPPzwwyLbs2ePyLQ9z1u0aCGyU089VWQXXnih7xBVyVwWGemlqqpKZLNmzRJZcXGxyJo0aVIrY4oC3gAAAGAQBQAAAAZRAAAAYBAFAAAABtEEqKioqBCZ1hiYkZEhMm2/aU3Lli2rPzDUqXhL1bZr166OR3LEgQMH1HzlypUi+/TTT0WmNfxpDVA33HCDyHr27OkzxLjy8/NDnQ98KxaLiWzHjh0ie/PNN0XWsWNHkfXv31+9j7b8u7aUsNYMHhW8AQAAwCAKAAAADKIAAADAIAoAAAAMMt8EqDV6aCuraQ1US5cuFdn27du97qs1kwDfOnjwoMi05jznnNu0aZPXNTMzM0V27bXXiixsw58mXkMlUF1ak7a2EqDm0ksv9T5Xm7NRbvjT8AYAAACDKAAAADCIAgAAAIMoAAAAMMhMJ5rW7Oec/za/WkOIb2NTVlaWyAYMGOB1LtLfF198IbKbb75ZZNpWpvFoW1BPnTpVZIWFhd7X9NWsWbMavyZs0hr0tOyTTz7x
ut7o0aO9751uDX8a3gAAAGAQBQAAAAZRAAAAYBAFAAAABqVlE6C2XaS2epRz/tv8ak1VvrRV3RYvXqweO2TIkITvg9SnNZ3eeeedIqtOw19eXp7XNbt27SqykpIS7/scrXHjxmqenZ2d8DWB/6Y9y7XslVde8TpO+xmI92ynCRAAAKQlCgAAAAyiAAAAwCAKAAAADErLJkCt0Urb4rc6fLebBL6LtoV0dRr+NJMmTRJZ3759RbZ79+5Q9zkazX6obVqT9jvvvCOyl19+WWTaFu4WtvitDt4AAABgEAUAAAAGUQAAAGAQBQAAAAZFvgmwrKxMZNVp+GvQQH4LMjMzRaY1lADV1b17d5Hdfvvtoa6pbS2t/QzU9Da9vtthA4nSVvPTVlbVjBo1SmTM2f/FbzUAAAyiAAAAwCAKAAAADKIAAADAIAoAAAAMitSnALSOUK3b2XcJSOecy8jI8DpfU1BQILKsrCyR+XatIv3l5+eLbOjQoUkYCZD6fJ/F2ie3unXrlvD1rOC7AQCAQRQAAAAYRAEAAIBBFAAAABgUqSZAbd/mnJycJIzkiPbt24ts9uzZIhsxYkRdDAcA0orWpK010mq/ByZOnFgrY0onvAEAAMAgCgAAAAyiAAAAwCAKAAAADAq01fX+y3f+IXAMsmuz7jGHEUYqzGHnmMf/UVpaKrLp06eLbMKECSJr3rx5rYwpAtR5zBsAAAAMogAAAMAgCgAAAAyiAAAAwKBjNQECAIA0xBsAAAAMogAAAMAgCgAAAAyiAAAAwCAKAAAADKIAAADAIAoAAAAMogAAAMAgCgAAAAyiAAAAwCAKAAAADKIAAADAIAoAAAAMogAAAMAgCoAEBEHwYhAE24MgKAqCYF0QBBOSPSagOoIgKD7qn8ogCKYne1yAL+ZweEEsFkv2GCInCIJezrkvYrFYWRAEPZxz7zjnzo3FYkuTOzKg+oIgyHLO7XTOjYjFYu8mezxAdTGHE8MbgATEYrHVsVis7Nv/++9/uiRxSEAYFzjndjnn3kv2QIAEMYcTQAGQoCAIngiC4JBz7lPn3Hbn3BtJHhKQqCudc8/HeB2I6GIOJ4D/BBBCEAT1nXPfd84Ncc49EIvFKpI7IqB6giDo4Jzb6JzrGovFNiZ7PEB1MYcTxxuAEGKxWGUsFlvknCt0zl2X7PEACbjCObeIBycijDmcIAqAmtHA0QOAaLrCOTcz2YMAQmAOJ4gCoJqCIGgVBMHFQRBkB0FQPwiCHzrnLnHO/SPZYwOqIwiCU51z7Zxzs5I9FiARzOFwGiR7ABEUc0de9/+fO1JAfemc+1ksFvtrUkcFVN+Vzrk5sVjsQLIHAiSIORwCTYAAABjEfwIAAMAgCgAAAAyiAAAAwCAKAAAADDrWpwDoEEQYQbIH4JjDCCcV5rBzzGOEo85j3gAAAGAQBQAAAAZRAAAAYBAFAAAABlEAAABgEAUAAAAGUQAAAGAQBQAAAAZRAAAAYBAFAAAABlEAAABgEAUAAAAGUQAAAGAQBQAAAAZRAAAAYBAFAAAABlEAAABgEAUAAAAGUQAAAGAQBQAAAAZRAAAAYBAFAAAABlEAAABgEAUAAAAGUQAAAGAQBQAAAAZRAAAAYBAFAAAABlEAAABgEAUAAAAGNUj2AFLRZ599JrIFCxaI7O233xZZEASh7n3GGWeI7IorrhBZTk5OqPugZr377rsiy8zMFFn37t1Flp2dLbIGDcL9aFZUVIisvLxcZPXr1xeZNoe1cxs2bKjeOyMjQ2T16vHvGqg9Bw4cENn69etF9uc//1lkkyZNUq+Zn58ffmApjp9KAAAMogAAAMAgCgAAAAyiAAAAwKAgFot9159/5x9GidbE5JxzEydOFNmbb74psrKyMpFp37uwTYDaNVu0aCGyxx57TGTDhg0TWbxGrToS7ptRM+pkDs+aNUtkq1at8jpXawJs3bp1qPFs27ZNZCUlJV7nhp3X2tjPO+88kbVv3977mkmUCnPY
uTR6FsejPWMXL14sss2bN4tMa8LVrrdu3TqR9enTRx3PHXfcIbLt27eLrGXLliLr1KmTyLSm4DqkzmPeAAAAYBAFAAAABlEAAABgEAUAAAAGmWkCHDlypJovWrTI63ytidC3UUtbbS2er7/+WmT79u0TWc+ePUX2/vvvi6xZs2be964FqdBAVSdzWPs5+uKLL0S2YsUKr+vt2rVLZDt27FCP1Zrpmjdv7nWfMLZu3arme/fuFVnHjh1FNn78+BofUy1IhTnsXBKfxcXFxV7HaStAaj8Xf//739Xzp0+fLjJtVVbtZ0Nb+TKsW2+9VWRjxozxOldrLNSafesQTYAAAOAICgAAAAyiAAAAwCAKAAAADDKzHfCXX36p5lrziNY498ILL4hsxIgRItOaXqqzsprWWKU1S7333nsiu++++0Q2bdo09T44tngNsqWlpSLTGvS0LX379+/vdR8t07Y8dU6f29r80q6pNSZ17txZvY/Puc4599BDD4mssrJSZMuWLROZ1jwVdnVN6A4ePCgybaW8Q4cOeV1PawJ86aWXRDZjxgz1fO3nyncea8f17dtXZJs2bRKZ1mTtnHMdOnRQ86NpX7eWpSLeAAAAYBAFAAAABlEAAABgEAUAAAAGmWkCnDdvnprffffdIrvppptE1rt3b5HVqyfrp7ANS61atfK6T5MmTUT29NNPi4wmwMQdPnxYzd96660avU/Y7Xe1+eFLa/Dy3cY4noULF4pswYIFImvXrp3ItK1VCwsLQ40HOm3e+Tb8abSGaq3RsHHjxur5Wq41ZF966aUiu+6660SmPSPvuusukWlbeTvnXG5urpofTVsRtlGjRl7nJhtvAAAAMIgCAAAAgygAAAAwiAIAAACDKAAAADAoLT8FoHW3tmjRQj328ccfF1n9+vW9sjDiLU186qmnimz//v0i05aaHDx4cPiBIWVpXc3O6Z9WKC8vr+3huI8++kjNtfFon27RuqdzcnLCDwxetK57bflcX9onV4qKirzu65xzF198sci0pbPjnX80bflpbXnuvXv3el0vnoKCglDnJxNvAAAAMIgCAAAAgygAAAAwiAIAAACD0rIJUGv+iLe3u7aMqtbwpzVVaQ0uzz//vMg+/fRTkWnL9sa7d1ZWlsjOPvtskd13333qNZGYeEvsduvWrdbvrS1Dqs1r5/QG13jz/Whaw97WrVu9rvfkk0+q1+zRo4fI2rZtKzKtkbVhw4bqNVHzfJ81YZxzzjkiGzRokHpsTTeAlpWViezNN98UWbyflUWLFolM+9nXGgujgjcAAAAYRAEAAIBBFAAAABhEAQAAgEHR7V74DlVVVSKLt5Kfln/88ccimzJlisjef/99kWmrYZWWlnod55y+ylWfPn1E9uyzz4osMzNTvSYSE2/OaE1u6UT7+rRG1hNPPFE9Py8vT2SdO3cWWdOmTUUW5YYq+Kmr1R61xm1t1T+t4dY550aOHCmyli1biizeszwKeAMAAIBBFAAAABhEAQAAgEEUAAAAGJSWHTdaI1G8Vd00DzzwgMg++OADr3O11ae0LN7Wrto4fVd1A2rLhx9+6H2sttX1BRdcILJ+/fqJTFsdEEjEihUrRKZtra6thumcvlX8WWedJbKbb75ZZL5bFicbbwAAADCIAgAAAIMoAAAAMIgCAAAAg9KyCbA6DX+aG264QWTz588XmdY8Mnr0aJH16tVLZFdffbV67y5duohs5cqVItu5c6fIOnbsqF4TqA6tiW/Tpk0i27Bhg3r+SSedJLJXX31VZNoWwSUlJR4jdC4/P19knTp1Uo+NSkMWEqfN2QcffFBk8bbU1ixZskRk7du3F5nW5B2VOccbAAAADKIAAADAIAoAAAAMogAAAMCgGmkCXLRokcjirRx22WWXiaxVq1Y1MYwac8YZZ4hs1apVItMalrRGPG11s+LiYvXe2lbGYZsagXi2bNkismeeecbr3N27d6v5iy++KLJu3bqJTGuo0mRlZYms
WbNmIou3fXP37t297oOapzVKz5o1Sz1W+z2ibbWrrYyqPYvnzp3rM8S4tG2LtS2C420nHAX8ZgEAwCAKAAAADKIAAADAIAoAAAAMqpEmQK0xQmvecM65k08+WWSp1gSo0VYtKy0tFdlXX30lsq1bt4rs4osvVu+jNfw1bdpUZPG2Ewbi0VYs++tf/+p1rjbf1q5dqx6rrWY5adIkkWnPDa3hT2vG0rb8RurRmp0XL16sHqv9zoj3e+RoGzdu9DpOu96wYcPUYx999FGRHX/88V73iQreAAAAYBAFAAAABlEAAABgEAUAAAAGUQAAAGBQnbfSTp06VWRnnnmmyIYOHSoyrYOzdevWItOW462oqFDH88knn4hMW2rynXfeEdkbb7whsuXLl4vMt5M1niuvvFJk2n7owLe0T6i8/vrrItu1a5fItJ+pU045RWTxlvLVPjGj/Ywj/TVu3Fhk/fr18z7/+9//vsi05dLPPfdcr+tpz/Zp06apx/bs2dPrmlHGGwAAAAyiAAAAwCAKAAAADKIAAADAoEBrivgv3/mH35o/f77Ixo4dm+CQ/n1jZVxaM53WZKI1MZWXl6v32bZtm8i0Bioty87OFpnvEqWDBw9W83vvvVdkvXv3Tvg+SRau+7FmeM3hKDt06JDI/va3v4lsxYoVIqtfv77IbrnlFpGtXr1aZHv27FHHoy0F3KtXL/XYCEiFOeycgXnsa//+/SIbMGCAyNatWyeys846S2SzZ89W79OsWbMERpey1HnMGwAAAAyiAAAAwCAKAAAADKIAAADAoBrpJBs+fLjINm/erB6r7bH81FNPiWzv3r0i0xqWtAYobW9obfUo55yrrKwUWW5urld2/vnni2zUqFEi01ZRA6pLa0R1zrkZM2aIbN++fSLTmlYvvfRSkfk2P8Vb4bJ58+Ze5wOJ0Fav9JWXlyeynJycMMOJNN4AAABgEAUAAAAGUQAAAGAQBQAAAAbV2nJyWsORc87dddddIrvuuutEtn79+hofk6akpERk2naVlhtFkBq0lc2c0xv+tJUitdU5CwoKRHb48GGvLJ4WLVp4Hwt8F60ZfPfu3Qlfb+DAgSKrV8/uvwfb/coBADCMAgAAAIMoAAAAMIgCAAAAg1JiT1ltdSYtC0Nb8S9eHpGtdpHG1q5dK7I5c+aox2orZE6YMEFkbdu29bq3ttKatgVrPI0aNfI+FviWtgW81gR4jC3s/0N7jl9yySXVH1ga4w0AAAAGUQAAAGAQBQAAAAZRAAAAYFBadrtpTSLxmgC1BirLK0Ohdu3cuVNkmzZtEtkbb7whsnjNqePGjROZb8Ofxne71VatWiV8D+BoW7ZsEZm26t9XX33llXXv3l1kjRs3TnB06YnfdAAAGEQBAACAQRQAAAAYRAEAAIBBadkEWFFRIbJ4jX1aEyBQE0pLS0X20ksviUxbZU9r+Lv44ovV+3Tp0iWB0R3h2zCrHRem0RA4mtYgq9G2cC8rKxNZ586dQ48p3fEGAAAAgygAAAAwiAIAAACDKAAAADCIAgAAAIMi/ykAreM/CAKRxVtGFagtWnf/j3/8Y69zc3NzRZaXlxd6TEc7cOCAyLZt2yayJk2aiKygoKDGxwMcy3HHHSeyHj16iExbIjsnJ6dWxhRVvAEAAMAgCgAAAAyiAAAAwCAKAAAADIp8Z5y2RGlGRkYSRgL8r9atWyd7CMekNfxpCgsLRUZjLWpS165dRbZ69WqRac2wH3zwgciaNm1aMwNLY7wBAADAIAoAAAAMogAAAMAgCgAAAAyKfBdPw4YNkz0EIBKqqqpEpq1WqMnOzq7p4QD/Q1v98rTTTkvCSOzgDQAAAAZRAAAAYBAFAAAABlEAAABgUKCtpAcAANIbbwAAADCIAgAAAIMoAAAAMIgCAAAAgygAAAAwiAIAAACDKAAAADCIAgAAAIMoAAAAMIgCAAAAgygAAAAwiAIAAACDKAAAADCIAgAAAIMoAKopCILio/6pDIJg
erLHBVQH8xhRxxwOr0GyBxA1sVgs+9v/HQRBlnNup3NuVvJGBFQf8xhRxxwOjzcA4VzgnNvlnHsv2QMBQmAeI+qYwwmgAAjnSufc87FYLJbsgQAhMI8RdczhBAR8vxITBEEH59xG51zXWCy2MdnjARLBPEbUMYcTxxuAxF3hnFvEhEPEMY8RdczhBFEAJO4K59zMZA8CCIl5jKhjDieI/wSQgCAITnXOzXfOtYnFYgeSPR4gEcxjRB1zOBzeACTmSufcHCYcIo55jKhjDofAGwAAAAziDQAAAAZRAAAAYBAFAAAABlEAAABg0LE2A6JDEGEEyR6AYw4jnFSYw84xjxGOOo95AwAAgEEUAAAAGEQBAACAQRQAAAAYRAEAAIBBFAAAABhEAQAAgEEUAAAAGEQBAACAQRQAAAAYRAEAAIBBFAAAABhEAQAAgEEUAAAAGEQBAACAQRQAAAAYRAEAAIBBFAAAABhEAQAAgEEUAAAAGEQBAACAQRQAAAAYRAEAAIBBFAAAABhEAQAAgEENkj0AC/7whz+I7Pjjj1ePPeWUU2p7OAAA8AYAAACLKAAAADCIAgAAAIMoAAAAMChSTYAlJSUia9Sokcj27NkjsuXLl6vXLCws9Lqm5r333hOZNsaVK1eKrGfPnuo1Dx8+LLIGDSL114SI27x5s8gefvhhkV111VXq+X369PG6j/azsmHDBq9zO3fuLLLGjRt7nQvgCN4AAABgEAUAAAAGUQAAAGAQBQAAAAZFqrvsoYceElleXp7ItGagpUuXqtfMz88X2axZs0RWVlYmsp07d4qsW7duXufOnz9fHY+Wa42K9evXV88HqkNrUB04cKDISktLRbZlyxb1mt/73vdE9swzz4hMawL8+uuv1Wsebdq0aSKbNGmS17mIjqKiIpHt3btXZG+99ZbIysvLRaY9s6dOneo9nuHDh4tsypQpIsvNzRWZ1vidkZHhfe/awBsAAAAMogAAAMAgCgAAAAyiAAAAwKAgFot9159/5x/WNW2VsOeee05kH3/8scjefvtt9Zo/+MEPRDZgwACRNWzYUGTFxcUiW7Nmjch2794tskGDBqnjGTp0qMi0RseuXbuKLCsrS71mEgXJHoBLsTlcVyorK0W2anqdCYoAAAo1SURBVNUqkU2YMEFk8RpmjxYE4f56tWeP7zWHDBkisgULFoQaTxypMIedq4V5/O6774pMW0VVs3DhQpFpTW7asy/ePbSGVG21VW1uh5lLYT322GMi6927t8hatWolsh49etTKmBTqN4M3AAAAGEQBAACAQRQAAAAYRAEAAIBBkWoC1FRUVIhMW2Hs+eefV8+/6KKLRKY10zVp0kRkWlPijBkzvK53++23q+PRVlxbtmyZ13i05sUkS4UGqpSfw7VB27538uTJIvNtnqqNJivfa7Zs2VJkmzZtElktbQecCnPYuRDzON7qivfee2/Cgwlj3bp1aj5v3jyRHeP3039oK7pqTXdt27YVWbyG7F/96lde99aO05pUmzZtKrK+fft63aMG0AQIAACOoAAAAMAgCgAAAAyiAAAAwKBIbQes0bZT1LIbb7yxLoaj0lbya9SokXpsZmamyLSGP9ijbSvtnHMXXHCByObOnZvwfbTGK20lzfXr16vna6sL+t5n9OjRInvxxRdFpv2cQBevOVJbue+cc87xuua+fftEpq36p83ZwYMHq9f85S9/6XXvmqatkFkdWsOfRmtmTTbeAAAAYBAFAAAABlEAAABgEAUAAAAGUQAAAGBQ5D8FEAVaF268ZVS3bt0qsqKiIpGlYkcpas7OnTtFdtttt6nHvvHGG17X7Nixo8gmTZoksjFjxohs7969IrvqqqvU+/guEax9Muf+++8XGR3/4cT7FMC1117rdb62PHlVVZXItOdUvE871TTta2zTpo1XdujQIfWa/fv3F9nHH3+cwOiO0L5nycYbAAAADKIAAADAIAoAAAAMogAAAMAgmgBD2LBhg9dx
AwYMENnmzZvVY7V9znNyckTWq1cvr3sjmnbt2iWyP/3pT97nFxQUiOy1114TmTaPvvzyS5FddNFFXsc5p+/Dfvfdd4ts/PjxIqPhL7kOHjwostWrV4tMawysDVoTYfv27UWmLbfuO5cWLFig5suXL/c6X9OggfzVmp+fn/D1agtvAAAAMIgCAAAAgygAAAAwiAIAAACDaAIMYc+ePSLTVnv64IMPRNapUyf1mlrDn7Zvd0ZGhscIEVX79+8X2ciRI9Vjf/jDH4rs+uuv97rPU089JTJthbhYLCayeCv+aaultW3b1ms8qDvaCnirVq0SWVlZWY3et3Xr1mqem5srMq25L8yzT/u5+uMf/6geW1lZKbKzzjpLZNrXozXhxluRMZl4AwAAgEEUAAAAGEQBAACAQRQAAAAYRBNgCNpKaLNnzxbZiBEjRDZw4ED1mj169BAZDX/2DBo0yCuLp6SkRGQ/+9nPRKY1AWpuvfVWkd1xxx3qsdpKgEg9WsNyTW9Zq61+pzXIOac3QIdx+PBhkVVUVIhs+/bt3tfU5nZ5ebnIsrOzva+ZTLwBAADAIAoAAAAMogAAAMAgCgAAAAyiCVChNcJ8+umnIjtw4IDIduzYITKteat3797qveOtrgbEo23N+sQTT4js6aefFpm23eqJJ54osilTpoisefPmvkNECtIa1bS/U21ram3eaA1/HTt2FFn9+vV9hxiKtrXx1KlTQ11z8uTJItMat6PyHOcNAAAABlEAAABgEAUAAAAGUQAAAGAQTYCKlStXimzfvn0i0xphbr75ZpFpDX9RaRJB6tBW93POuVGjRols3rx5ItPm3NixY0U2c+bMBEaHdNClSxeRac19TZo0EVkyt7vVGv7Wrl0rsunTp3tfs1evXiI77rjjRBbllVp5AwAAgEEUAAAAGEQBAACAQRQAAAAYZL4JUFvNb+HChSLbu3evyBo08Pv2vfrqqyKbO3eueqzW0NW9e3ev+yB9aA1M48aNU49dsWKF1zW1hj/f7YBhg9bQ1qJFiySMJD6t4W/p0qUi2717t8hisZj3fX7605+KrKa3LE423gAAAGAQBQAAAAZRAAAAYBAFAAAABlEAAABgUFp+CqCqqkpkc+bMUY/VOqi1vdS/+eYbkWn7pg8dOlRkWteqtpSmc84VFBSoOdLX6tWrRfaTn/xEZMuXL/e+5oABA0T27LPPiqxhw4be1wTqWmlpqci2bdvmde6TTz4psnr15L/znnLKKer5t9xyi9d9oow3AAAAGEQBAACAQRQAAAAYRAEAAIBBadkEuH//fpH961//8j5//PjxInv55ZdFpi0/2bZtW5ENGzZMZBMnTlTvnZmZ6TNERNTmzZtFNnDgQJGVlJSILAgC9Zp33nmnyO655x6R+S5dDSSDtkzv559/LrJ9+/Z5XW/BggUi05qvL7/8cvX8eD9v6YQ3AAAAGEQBAACAQRQAAAAYRAEAAIBBadkVlJ2dLTKtES+eXr16ieyGG24Q2bRp00Q2c+ZMkZ1++ukio9kv/b399tsi05o/Dx06JLJGjRqJ7LnnnlPvM2rUKJHR8IdUpa2M6pxzO3bsEJlvw9+MGTNE1rp1a5FdeOGFIhs7dqzXPdIRbwAAADCIAgAAAIMoAAAAMIgCAAAAg9KyUygjI0Nkw4cPr/H7/OhHPxLZ448/LrI1a9bU+L2RWrQV/h544AGRbdy40et6l112mcguuuii6g8MSDHa1urOObd169aEr/nSSy+JTPt5admypcjy8vISvm/U8QYAAACDKAAAADCIAgAAAIMoAAAAMCgtmwCTqV49WVPl5uYmYSSoS1dffbXI5s2b53Wutg3qmWeeGXpMQLJt27ZNZNqKf9XxyiuviGzkyJEi07b+1Rq3LeMNAAAABlEAAABgEAUAAAAGUQAAAGBQ5JsAtQaqyspKkWnNefHyw4cPi6yoqEhkH374ocgaN24ssjZt2qj3RjRdfvnlInvrrbdEFgSB1/U2
bNggsnbt2lV/YEASlZWViWz79u0ii7cdsK/p06eLrLCwUGRdunQR2QknnBDq3umGNwAAABhEAQAAgEEUAAAAGEQBAACAQZFvAiwtLRXZkiVLRJaTk6Oen5WVJTJt9aqqqiqR9evXT2RjxozxvjeiSdu2VJsfnTp1Etmjjz7qdRwQNeXl5SIL2/Cnrfqn6dGjh8iuueaaUPe2gDcAAAAYRAEAAIBBFAAAABhEAQAAgEEUAAAAGBRoS+n+l+/8w1S1f/9+ka1bt8772BYtWogsNzdXZG3bthVZZmamzxCt8FsLt3bV+Bx+4oknRLZ+/XqR3XPPPSJr2rRpTQ8HtSsV5rBzEXgWa0tab9myJdQ1J0yYILITTzxRZK1atRKZ9okbw9R5zBsAAAAMogAAAMAgCgAAAAyiAAAAwKC0bAJEykiFBirmMMJIhTnsXATmsbYU8Jo1a0RWVFTkfc2ysjKRvf766yLTlmAfPHiw930MoAkQAAAcQQEAAIBBFAAAABhEAQAAgEE0AaI2pUIDFXMYYaTCHHaOeYxwaAIEAABHUAAAAGAQBQAAAAZRAAAAYNCxmgABAEAa4g0AAAAGUQAAAGAQBQAAAAZRAAAAYBAFAAAABlEAAABg0P8DJwmTlFcjsWoAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 648x648 with 9 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "mixup = MixUp(0.5)\n",
    "with Learner(dls, nn.Linear(3,4), loss_func=CrossEntropyLossFlat(), cbs=mixup) as learn:\n",
    "    learn.epoch,learn.training = 0,True\n",
    "    learn.dl = dls.train\n",
    "    b = dls.one_batch()\n",
    "    learn._split(b)\n",
    "    learn('before_batch')\n",
    "\n",
    "_,axs = plt.subplots(3,3, figsize=(9,9))\n",
    "dls.show_batch(b=(mixup.x,mixup.y), ctxs=axs.flatten())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
